<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>tinygrad has WebGPU</title>
    <style>
        /* General reset: strip default margins/padding everywhere and size every
           element with border-box so padding is included in declared widths. */
        * {
            margin: 0;
            padding: 0;
            box-sizing: border-box;
        }

        /* The page is a single centered flex column that fills at least the viewport height. */
        body {
            font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
            background: #f4f7fb;
            color: #333;
            display: flex;
            justify-content: center;
            align-items: center;
            min-height: 100vh;
            padding: 20px;
            flex-direction: column;
            text-align: center;
        }

        /* Shared heading look; the two id rules below override the font size per heading. */
        h1 {
            font-size: 2.2rem;
            margin-bottom: 20px;
            color: #4A90E2;
        }

        /* Error heading: hidden by default, the page script switches it to display: block
           when navigator.gpu is missing. */
        #wgpuError {
            color: red;
            font-size: 1.2rem;
            margin-top: 20px;
            display: none;
        }

        /* Regular page title (hidden by the script when WebGPU is unavailable). */
        #sdTitle {
            font-size: 1.5rem;
            margin-bottom: 30px;
        }

        /* Card that holds the prompt form and the progress indicators.
           width is 100% (not 120%): a percentage above 100 made the card wider than the
           padded body on narrow viewports, pushing it off-screen on both sides. On wide
           viewports max-width still caps the card at 550px, so the desktop layout is
           unchanged. */
        #mybox {
            background: #ffffff;
            padding: 15px;
            border-radius: 10px;
            box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
            width: 100%;
            max-width: 550px;
            margin-bottom: 10px;
        }

        /* Prompt text field: full width of the card. The default focus outline is
           removed; the :focus rule below recolors the border instead. */
        input[type="text"] {
            width: 100%;
            padding: 12px;
            margin-bottom: 20px;
            font-size: 1rem;
            border: 1px solid #ccc;
            border-radius: 8px;
            outline: none;
            transition: all 0.3s ease;
        }

        input[type="text"]:focus {
            border-color: #4A90E2;
        }

        /* Each label is a flex row: caption text, current value, then the slider. */
        label {
            display: flex;
            justify-content: space-between;
            font-size: 1rem;
            margin-bottom: 15px;
            align-items: center;
        }

        /* Sliders: native appearance is disabled so the track can be drawn as a thin
           rounded bar. Focus is indicated by the track color change below. */
        input[type="range"] {
            width: 100%;
            margin-left: 10px;
            -webkit-appearance: none;
            appearance: none;
            height: 8px;
            border-radius: 4px;
            background: #ddd;
            outline: none;
            transition: background 0.3s ease;
        }

        input[type="range"]:focus {
            background: #4A90E2;
        }

        /* The id selectors outrank the generic range rule above, so both sliders
           end up at 80% width rather than 100%. */
        #stepRange,
        #guidanceRange {
            width: 80%;
        }

        /* Applies to every span on the page (slider values and status texts). */
        span {
            font-size: 1.1rem;
            font-weight: 600;
            color: #333;
        }

        /* Run button: full-width primary action inside the card. */
        input[type="button"] {
            padding: 12px 25px;
            background-color: #4A90E2;
            color: #fff;
            font-size: 1.2rem;
            border: none;
            border-radius: 8px;
            cursor: pointer;
            transition: background-color 0.3s ease;
            width: 100%;
            margin-bottom: 20px;
        }

        /* Greyed-out look while the model is loading or running. */
        input[type="button"]:disabled {
            background-color: #ccc;
            cursor: not-allowed;
        }

        /* Hover highlight only for an enabled button. Without :not(:disabled) this
           rule (same specificity, later in the sheet) overrode the disabled
           background, so a disabled button turned blue under the pointer. */
        input[type="button"]:hover:not(:disabled) {
            background-color: #357ABD;
        }

        /* Rows holding a progress bar and its text. Note: #divStepProgress also carries
           an inline display: none in the markup, which wins over this rule until the
           script sets it to flex. */
        #divModelDl, #divStepProgress {
            display: flex;
            justify-content: space-between;
            align-items: center;
            margin-bottom: 20px;
        }

        /* Both progress bars share the same thin rounded track. */
        #modelDlProgressBar,
        #progressBar {
            width: 80%;
            height: 12px;
            border-radius: 6px;
            background-color: #e0e0e0;
        }

        /* WebKit-specific pseudo-element for the track, rounded to match the rule above. */
        #modelDlProgressBar::-webkit-progress-bar,
        #progressBar::-webkit-progress-bar {
            border-radius: 6px;
        }

        /* Text shown next to the bars (percentage / "step/steps"). */
        #modelDlProgressValue, #progressFraction {
            font-size: 1rem;
            font-weight: 600;
            color: #333;
        }

        /* Output canvas: scaled down by CSS if needed; the drawing buffer size itself
           comes from the width/height attributes on the element. */
        canvas {
            max-width: 100%;
            max-height: 450px;
            margin-top: 10px;
            border-radius: 8px;
            border: 1px solid #ddd;
        }
    </style>

    <script type="module">
        import ClipTokenizer from './clip_tokenizer.js';
        window.clipTokenizer = new ClipTokenizer();
    </script>
    <script src="./net.js"></script>
</head>
<body>
    <!-- Only one of the two headings is visible at a time: the script shows #wgpuError
         and hides #sdTitle when navigator.gpu is missing. -->
    <h1 id="wgpuError" style="display: none;">WebGPU is not supported in this browser</h1>
    <h1 id="sdTitle">StableDiffusion powered by <a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="color: #4A90E2;">tinygrad</a></h1>
    <!-- Corner link to the repository. rel="noopener noreferrer" keeps the opened tab from
         reaching back into this page; the alt text names the link destination because the
         image is the link's only content. -->
    <a href="https://github.com/tinygrad/tinygrad" target="_blank" rel="noopener noreferrer" style="position: absolute; top: 20px; right: 20px;">
        <img src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg"
             alt="tinygrad on GitHub"
             width="32" height="32"
             style="width: 32px; height: 32px;">
    </a>
    
    <!-- Control card: prompt form plus model-loading and per-step progress.
         All ids are referenced by the page script and must stay unchanged. -->
    <div id="mybox">
        <form id="promptForm">
            <!-- The placeholder disappears on input and is not a reliable label, so the
                 field also gets an explicit accessible name. -->
            <input id="promptText" type="text" aria-label="Prompt" placeholder="Enter your prompt here" value="a human standing on the surface of mars">

            <label>
                Steps: <span id="stepValue">9</span>
                <input id="stepRange" type="range" min="5" max="20" value="9" step="1">
            </label>

            <label>
                Guidance: <span id="guidanceValue">8.0</span>
                <input id="guidanceRange" type="range" min="3" max="15" value="8.0" step="0.1">
            </label>

            <!-- Starts disabled; the script enables it once the model has loaded. -->
            <input id="btnRunNet" type="button" value="Run" disabled>
        </form>

        <!-- Model loading status. The title span is a polite live region so phase changes
             written by the script are announced, and it names the progress bar. -->
        <div id="divModelDl" style="display: flex; align-items: center; width: 100%; gap: 10px;">
            <span id="modelDlTitle" role="status">Downloading model</span>
            <progress id="modelDlProgressBar" value="0" max="100" aria-labelledby="modelDlTitle" style="flex-grow: 1;"></progress>
            <span id="modelDlProgressValue"></span>
        </div>

        <!-- Per-step generation progress; switched to display: flex by the script after loading. -->
        <div id="divStepProgress" style="display: none; align-items: center; width: 100%; gap: 10px;">
            <progress id="progressBar" value="0" max="100" aria-label="Generation progress" style="flex-grow: 1;"></progress>
            <span id="progressFraction"></span>
        </div>

        <!-- Duration of the most recent diffusion step; shown by the script during a run. -->
        <div id="divStepTime" style="display: none; align-items: center; width: 100%; gap: 10px;">
            <span id="stepTimeValue">0 ms</span>
        </div>
    </div>

    <!-- 512x512 output surface the script paints the generated image into. A canvas has no
         built-in accessible name, so it is exposed as a labelled image. -->
    <canvas id="canvas" width="512" height="512" role="img" aria-label="Generated image"></canvas>

<script>
    // Holds the f16 -> f32 conversion routine. It is null until loadNet() assigns the
    // result of f16tof32().setup(...); getAndDecompressF16Safetensors() then awaits it
    // once per Uint32Array chunk of packed weights.
    let f16decomp = null;

    /**
     * Opens the 'tinydb' IndexedDB database (version 1) and, during an upgrade,
     * creates the 'tensors' object store keyed by `id` if it does not exist yet.
     *
     * The database is used purely as a cache, so this function never rejects:
     * an error event, a synchronous exception from `indexedDB.open`, or
     * `indexedDB` not being defined at all are logged and reported as `null`.
     * The read/save helpers treat a null database as "no cache".
     *
     * @returns {Promise<IDBDatabase|null>} the open database, or null when unavailable.
     */
    function initDb() {
        return new Promise((resolve) => {
            let request;

            try {
                request = indexedDB.open('tinydb', 1);
            } catch (error) {
                // open() can throw synchronously; previously that rejected the promise
                // and aborted the whole page initialisation.
                console.error('Database error:', error);
                resolve(null);
                return;
            }

            request.onupgradeneeded = (event) => {
                const db = event.target.result;
                if (!db.objectStoreNames.contains('tensors')) {
                    db.createObjectStore('tensors', { keyPath: 'id' });
                }
            };

            request.onsuccess = (event) => {
                console.log("Db initialized.");
                resolve(event.target.result);
            };

            request.onerror = (event) => {
                console.error('Database error:', event.target.error);
                resolve(null);
            };
        });
    }

    /**
     * Stores `tensor` under `id` in the 'tensors' object store unless a record with
     * that id already exists. Best-effort: failures are logged, never thrown.
     *
     * Changes from the previous version:
     *  - the returned promise now settles only after the write transaction has
     *    finished (the inner promise used to be created but not returned);
     *  - the existence check uses `count(id)` inside the same transaction instead
     *    of reading the full stored record back first;
     *  - a null database and a synchronously throwing `db.transaction` are
     *    handled explicitly.
     *
     * @param {IDBDatabase|null} db database from initDb(), or null when caching is off.
     * @param {string} id record key.
     * @param {*} tensor value stored in the record's `content` field.
     * @returns {Promise<undefined|null>} undefined after a successful write; null when
     *     nothing was written (no database, record already present, or failure).
     */
    function saveTensorToDb(db, id, tensor) {
        if (db == null) {
            return Promise.resolve(null);
        }

        return new Promise((resolve) => {
            let transaction;
            let written = false;

            try {
                transaction = db.transaction(['tensors'], 'readwrite');
                const store = transaction.objectStore('tensors');
                const countRequest = store.count(id);

                countRequest.onsuccess = () => {
                    if (countRequest.result > 0) {
                        // Already cached; let the transaction complete without writing.
                        return;
                    }

                    const putRequest = store.put({ id: id, content: tensor });

                    putRequest.onsuccess = () => {
                        written = true;
                    };

                    putRequest.onerror = (event) => {
                        console.error('Tensor save failed:', event.target.error);
                    };
                };

                countRequest.onerror = (event) => {
                    console.error('Tensor save failed:', event.target.error);
                };
            } catch (error) {
                console.error('Tensor save failed:', error);
                resolve(null);
                return;
            }

            // The record is only durable once the transaction completes.
            transaction.oncomplete = () => {
                if (written) {
                    console.log('Tensor saved successfully.');
                    resolve();
                } else {
                    resolve(null);
                }
            };

            transaction.onabort = (event) => {
                console.log("Transaction error while saving tensor: " + event.target.error);
                resolve(null);
            };
        });
    }

    /**
     * Looks up the record stored under `id` in the 'tensors' object store.
     * Best-effort: every failure path resolves with null, the promise never rejects.
     *
     * The null-database guard now returns immediately; before, execution fell
     * through to `db.transaction` and only worked because the resulting TypeError
     * was swallowed by the already-resolved promise. A synchronous exception from
     * `db.transaction` (for example on a closing connection) is now reported as
     * null as well instead of rejecting.
     *
     * @param {IDBDatabase|null} db database from initDb(), or null when caching is off.
     * @param {string} id record key.
     * @returns {Promise<Object|null>} the stored record (as written by saveTensorToDb,
     *     i.e. `{ id, content }`), or null when missing or unavailable.
     */
    function readTensorFromDb(db, id) {
        return new Promise((resolve) => {
            if (db == null) {
                resolve(null);
                return;
            }

            let request;

            try {
                const transaction = db.transaction(['tensors'], 'readonly');

                transaction.onabort = (event) => {
                    console.log("Transaction error while reading tensor: " + event.target.error);
                    resolve(null);
                };

                request = transaction.objectStore('tensors').get(id);
            } catch (error) {
                console.error('Tensor retrieve failed: ', error);
                resolve(null);
                return;
            }

            request.onsuccess = (event) => {
                // get() succeeds with undefined for a missing key; normalise to null.
                resolve(event.target.result || null);
            };

            request.onerror = (event) => {
                console.error('Tensor retrieve failed: ', event.target.error);
                resolve(null);
            };
        });
    }

    window.addEventListener('load', async function() {
        if (!navigator.gpu) {
            document.getElementById("wgpuError").style.display = "block";
            document.getElementById("sdTitle").style.display = "none";
            return;
        }

        let db = await initDb();

        const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
        let labels, nets, safetensorParts;

        /**
         * Requests a WebGPU device with 1 GiB buffer limits and the "shader-f16" feature.
         *
         * `powerPreference` is an adapter-selection option, so it is now passed to
         * `requestAdapter()`; in the device descriptor (where it used to be) it is
         * not a recognised member and had no effect. A null adapter is reported with
         * an explicit error instead of a TypeError on `adapter.requestDevice`.
         *
         * @returns {Promise<GPUDevice>} the configured device.
         * @throws {Error} when no adapter is available; `requestDevice` itself rejects
         *     when the adapter cannot satisfy the required limits or feature.
         */
        const getDevice = async () => {
            const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });

            if (!adapter) {
                throw new Error("No suitable WebGPU adapter found");
            }

            // Size of the largest buffer the model uses, per the original constant name.
            const maxBufferSizeInSDModel = 1073741824;

            return await adapter.requestDevice({
                requiredLimits: {
                    maxStorageBufferBindingSize: maxBufferSizeInSDModel,
                    maxBufferSize: maxBufferSizeInSDModel
                },
                requiredFeatures: ["shader-f16"]
            });
        };

        /**
         * Awaits `func()`, logs how long it took as "<ms> ms <label>" (one decimal),
         * and passes the result through.
         *
         * @param {Function} func callable whose (possibly async) result is awaited.
         * @param {string} [label=""] text appended to the log line.
         * @returns {Promise<*>} whatever `func()` resolved to.
         */
        const timer = async (func, label = "") => {
            const startedAt = performance.now();
            const result = await func();
            const elapsedMs = performance.now() - startedAt;

            console.log(`${elapsedMs.toFixed(1)} ms ${label}`);

            return result;
        };

        /**
         * Downloads `part` and reports progress after every received chunk.
         *
         * Fixes relative to the previous version:
         *  - a non-2xx response now throws instead of returning the error page body
         *    as if it were model data (which the caller would then also cache);
         *  - a missing or unparsable Content-Length is reported as total 0 ("unknown")
         *    rather than NaN.
         *
         * @param {string} part URL to fetch.
         * @param {Function} progressCallback called as (part, chunkByteLength, total)
         *     for each chunk; `total` is the Content-Length in bytes or 0 when unknown.
         * @returns {Promise<ArrayBuffer>} the complete response body.
         * @throws {Error} when the server answers with a non-OK status.
         */
        const getProgressDlForPart = async (part, progressCallback) => {
            const response = await fetch(part);

            if (!response.ok) {
                throw new Error(`Failed to download ${part}: HTTP ${response.status}`);
            }

            const parsedLength = parseInt(response.headers.get('content-length'), 10);
            const total = Number.isFinite(parsedLength) ? parsedLength : 0;

            // Re-wrap the body in a stream we control so each chunk can be observed
            // on its way into the final ArrayBuffer.
            const res = new Response(new ReadableStream({
                async start(controller) {
                    const reader = response.body.getReader();
                    for (;;) {
                        const { done, value } = await reader.read();
                        if (done) break;
                        progressCallback(part, value.byteLength, total);
                        controller.enqueue(value);
                    }

                    controller.close();
                },
            }));

            return res.arrayBuffer();
        };
  
        /**
         * Fetches the five weight files (from the IndexedDB cache when present,
         * otherwise over the network), concatenates the first four, expands that data
         * chunk by chunk through `f16decomp`, and lays the result out into four
         * Uint8Array parts together with the text model.
         *
         * Changes from the previous version (output layout is unchanged):
         *  - `combinedBuffer` is a local instead of an implicit global;
         *  - only freshly downloaded files are written to the cache; cache hits used
         *    to be passed to saveTensorToDb again;
         *  - the no-op `buffer = null` inside the copy loop is replaced by dropping the
         *    array slot, so copied buffers are no longer kept alive by `buffers`;
         *  - cached Uint8Arrays are used directly instead of being copied first;
         *  - progress is not reported while the total download size is unknown.
         *
         * NOTE(review): the numeric constants below (text model offset, part boundaries,
         * chunk size) are taken as-is and presumably match the exported model files.
         * The Uint32Array view also assumes 8 + metadataLength and the data length are
         * multiples of 4 - confirm against the exporter.
         *
         * @param {GPUDevice} device unused here; kept for interface compatibility.
         * @param {Function} progress called as (done, total) for both the download
         *     phase (bytes) and the decompression phase (chunks).
         * @returns {Promise<Uint8Array[]>} the four output parts.
         */
        const getAndDecompressF16Safetensors = async (device, progress) => {
            let totalLoaded = 0;
            let totalSize = 0;
            const partSize = {};
            // Keys fetched from the network during this call, i.e. the ones worth caching.
            const downloadedKeys = new Set();

            // Views an ArrayBuffer (download) or returns a Uint8Array (cache hit) unchanged.
            const asUint8Array = (data) => (data instanceof Uint8Array ? data : new Uint8Array(data));

            const progressCallback = (part, loaded, total) => {
                totalLoaded += loaded;

                if (!partSize[part]) {
                    // total is not a finite number when Content-Length was missing.
                    totalSize += Number.isFinite(total) ? total : 0;
                    partSize[part] = true;
                }

                if (totalSize > 0) {
                    progress(totalLoaded, totalSize);
                }
            };

            const getPart = async (key) => {
                const part = await readTensorFromDb(db, key);

                if (part) {
                    console.log(`Cache hit: ${key}`);
                    return part.content;
                }

                console.log(`Cache miss: ${key}`);
                downloadedKeys.add(key);
                return getProgressDlForPart(`${window.MODEL_BASE_URL}/${key}.safetensors`, progressCallback);
            };

            const netKeys = ["net_part0", "net_part1", "net_part2", "net_part3", "net_textmodel"];
            const buffers = await Promise.all(netKeys.map(key => getPart(key)));

            // Combine everything except for text model, since that's already f32
            const totalLength = buffers.reduce(
                (acc, buffer, index) => (index < 4 ? acc + buffer.byteLength : acc),
                0
            );

            let combinedBuffer = new Uint8Array(totalLength);
            let offset = 0;
            buffers.forEach((buffer, index) => {
                const bytes = asUint8Array(buffer);

                if (downloadedKeys.has(netKeys[index])) {
                    // Fire-and-forget: caching is best-effort and must not delay decoding.
                    saveTensorToDb(db, netKeys[index], bytes);
                }

                if (index < 4) {
                    combinedBuffer.set(bytes, offset);
                    offset += buffer.byteLength;
                    // The data now lives in combinedBuffer; drop our reference to the source.
                    buffers[index] = null;
                }
            });

            let textModelU8 = asUint8Array(buffers[4]);
            document.getElementById("modelDlTitle").innerHTML = "Decompressing model";

            const textModelOffset = 3772703308;
            // First 8 bytes: little-endian u64 length of the JSON metadata that follows.
            const metadataLength = Number(new DataView(combinedBuffer.buffer).getBigUint64(0, true));
            const metadata = JSON.parse(new TextDecoder("utf8").decode(combinedBuffer.subarray(8, 8 + metadataLength)));

            const allToDecomp = combinedBuffer.byteLength - (8 + metadataLength);
            const decodeChunkSize = 8388480;
            const numChunks = Math.ceil(allToDecomp/decodeChunkSize);

            console.log(allToDecomp + " bytes to decompress");
            console.log("Will be decompressed in " + numChunks+ " chunks");

            let partOffsets = [{start: 0, end: 1131408336}, {start: 1131408336, end: 2227518416}, {start: 2227518416, end: 3308987856}, {start: 3308987856, end: 4265298864}];
            let parts = [];

            for (let offsets of partOffsets) {
                parts.push(new Uint8Array(offsets.end-offsets.start));
            }
            // Output starts with the same length prefix and metadata as the input.
            parts[0].set(new Uint8Array(new BigUint64Array([BigInt(metadataLength)]).buffer), 0);
            parts[0].set(combinedBuffer.subarray(8, 8 + metadataLength), 8);
            // The text model is copied unconverted to its fixed position in the last part.
            parts[3].set(textModelU8, textModelOffset+8+metadataLength - partOffsets[3].start);

            let start = Date.now();
            let cursor = 0;

            for (let i = 0; i < numChunks; i++) {
                progress(i, numChunks);
                let chunkStartF16 = 8 + metadataLength + (decodeChunkSize * i);
                let chunkEndF16 = chunkStartF16 + decodeChunkSize;
                let chunk = combinedBuffer.subarray(chunkStartF16, chunkEndF16);
                let uint32Chunk = new Uint32Array(chunk.buffer, chunk.byteOffset, chunk.byteLength / 4);
                let result = await f16decomp(uint32Chunk);
                let resultUint8 = new Uint8Array(result.buffer);
                // Output offsets advance at twice the input rate (the write position of
                // chunk i is decodeChunkSize * i * 2).
                let chunkStartF32 = 8 + metadataLength + (decodeChunkSize * i * 2);
                let chunkEndF32 = chunkStartF32 + resultUint8.byteLength;
                let offsetInPart = chunkStartF32 - partOffsets[cursor].start;

                if (chunkEndF32 < partOffsets[cursor].end || cursor === parts.length - 1) {
                    parts[cursor].set(resultUint8, offsetInPart);
                } else {
                    // The chunk straddles a part boundary: fill the rest of the current
                    // part, then continue at the beginning of the next one.
                    let spaceLeftInCurrentPart = partOffsets[cursor].end - chunkStartF32;
                    parts[cursor].set(resultUint8.subarray(0, spaceLeftInCurrentPart), offsetInPart);

                    cursor++;

                    if (cursor < parts.length) {
                        let nextPartOffset = spaceLeftInCurrentPart;
                        let nextPartLength = resultUint8.length - nextPartOffset;
                        parts[cursor].set(resultUint8.subarray(nextPartOffset, nextPartOffset + nextPartLength), 0);
                    }
                }

                resultUint8 = null;
                result = null;
            }

            combinedBuffer = null;

            let end = Date.now();
            console.log("Decoding took: " + ((end - start) / 1000) + " s");
            console.log("Avarage " + ((end - start) / numChunks) + " ms per chunk");

            return parts;
        };

        /**
         * Acquires the GPU device, prepares the f16 decompressor, loads and decompresses
         * the weights, compiles the three networks into `nets`, and finally enables the
         * Run button.
         *
         * Changes from the previous version:
         *  - the two assignments that were joined by a stray comma operator are now
         *    separate statements;
         *  - any failure is logged and shown in the status line instead of becoming an
         *    unhandled rejection that leaves the page on "Downloading model";
         *  - the progress helper clamps to 0-100 and ignores non-finite input, because
         *    assigning NaN or Infinity to a progress element's `value` throws.
         *
         * `f16tof32`, `textModel`, `diffusor` and `decoder` are not defined in this file
         * (presumably provided by net.js).
         *
         * @returns {Promise<void>} settles when loading finished or failed; never rejects.
         */
        const loadNet = async () => {
            const modelDlTitle = document.getElementById("modelDlTitle");
            const modelDlProgressBar = document.getElementById("modelDlProgressBar");
            const modelDlProgressValue = document.getElementById("modelDlProgressValue");

            const progress = (loaded, total) => {
                const ratio = total > 0 ? loaded / total : 0;
                const percent = Number.isFinite(ratio) ? Math.min(Math.max(ratio, 0), 1) * 100 : 0;

                modelDlProgressBar.value = percent;
                modelDlProgressValue.textContent = Math.trunc(percent) + "%";
            };

            try {
                const device = await getDevice();
                // safetensorParts is still unset at this point; it is passed exactly as
                // the original call did. NOTE(review): confirm f16tof32's setup ignores it.
                f16decomp = await f16tof32().setup(device, safetensorParts);
                safetensorParts = await getAndDecompressF16Safetensors(device, progress);

                modelDlTitle.textContent = "Compiling model";

                const models = ["textModel", "diffusor", "decoder"];

                // Compile all three networks concurrently, then index them by name.
                nets = await timer(() => Promise.all([
                    textModel().setup(device, safetensorParts),
                    diffusor().setup(device, safetensorParts),
                    decoder().setup(device, safetensorParts)
                ]).then((loadedModels) => loadedModels.reduce((acc, model, index) => {
                    acc[models[index]] = model;
                    return acc;
                }, {})), "(compilation)");

                progress(1, 1);

                modelDlTitle.textContent = "Model ready";
                // Leave the full bar visible briefly before swapping in the step progress row.
                setTimeout(() => {
                    modelDlProgressBar.style.display = "none";
                    modelDlProgressValue.style.display = "none";
                    document.getElementById("divStepProgress").style.display = "flex";
                }, 1000);
                document.getElementById("btnRunNet").disabled = false;
            } catch (error) {
                console.error("Model loading failed:", error);
                modelDlTitle.textContent = "Failed to load model: " + (error && error.message ? error.message : error);
            }
        };

        /**
         * Runs the text encoder, the diffusion loop and the decoder for one prompt.
         *
         * Changes from the previous version:
         *  - it is a plain async function; the earlier `new Promise(async ...)` wrapper
         *    swallowed every exception, so a failure left the returned promise pending
         *    forever and the Run button disabled;
         *  - `t`, `alphas_prev` and `latent` are locals instead of implicit globals;
         *  - `steps` is converted to a number once and validated (a negative value made
         *    the timestep loop run forever);
         *  - the noise generator uses `1 - Math.random()` so `Math.log` never sees 0.
         *
         * `getWeight` and `clipTokenizer` are not defined in this script block
         * (presumably provided by net.js and the module script in the head).
         *
         * @param {string} prompt text to condition on.
         * @param {number|string} steps number of diffusion steps (slider value).
         * @param {number|string} guidance guidance scale handed to the diffusor.
         * @param {?Function} showStep optional callback receiving the decoded image
         *     after each step; pass null to decode only once at the end.
         * @returns {Promise<*>} the decoder output for the final latent.
         * @throws {RangeError} when `steps` is not a positive number.
         */
        async function runStableDiffusion(prompt, steps, guidance, showStep) {
            const numSteps = Number(steps);

            if (!(numSteps > 0)) {
                throw new RangeError("steps must be a positive number");
            }

            const context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP(prompt)));
            const unconditional_context = await timer(() => nets["textModel"](clipTokenizer.encodeForCLIP("")));

            // Evenly spaced (possibly fractional) timesteps in [1, 1000).
            const timesteps = [];

            for (let i = 1; i < 1000; i += (1000 / numSteps)) {
                timesteps.push(i);
            }

            console.log("Timesteps: " + timesteps);

            const alphasCumprod = getWeight(safetensorParts,"alphas_cumprod");
            const alphas = timesteps.map((t) => alphasCumprod[Math.floor(t)]);
            // alphas shifted by one position, starting from 1.0.
            const alphas_prev = [1.0, ...alphas.slice(0, -1)];

            const inpSize = 4*64*64;
            let latent = new Float32Array(inpSize);

            // Box-Muller transform: standard normal samples from two uniform draws.
            // 1 - Math.random() lies in (0, 1], which keeps the logarithm finite.
            for (let i = 0; i < inpSize; i++) {
                latent[i] = Math.sqrt(-2.0 * Math.log(1 - Math.random())) * Math.cos(2.0 * Math.PI * Math.random());
            }

            // Walk the schedule from the last timestep down to the first.
            for (let i = timesteps.length - 1; i >= 0; i--) {
                const timestep = new Float32Array([timesteps[i]]);
                const start = performance.now();
                const x_prev = await nets["diffusor"](unconditional_context, context, latent, timestep, new Float32Array([alphas[i]]), new Float32Array([alphas_prev[i]]), new Float32Array([guidance]));
                document.getElementById("divStepTime").style.display = "block";
                document.getElementById("stepTimeValue").innerText = `${(performance.now() - start).toFixed(1)} ms / step`;
                latent = x_prev;

                if (showStep != null) {
                    showStep(await nets["decoder"](latent));
                }

                document.getElementById("progressBar").value = ((numSteps - i) / numSteps) * 100;
                document.getElementById("progressFraction").innerHTML = (numSteps - i) + "/" + numSteps;
            }

            return await timer(() => nets["decoder"](latent));
        }

        /**
         * Paints a 512x512 image onto the canvas.
         *
         * `image` is read as consecutive R, G, B triples in row-major order; alpha is
         * set to fully opaque. Values are written straight into a preallocated
         * Uint8ClampedArray (same clamping/rounding as before) instead of being pushed
         * one by one into a plain array of about a million entries that was then copied.
         *
         * @param {ArrayLike<number>} image at least 512*512*3 channel values; missing
         *     entries render as 0, as they did before.
         */
        function renderImage(image) {
            const width = 512;
            const height = 512;
            const rgba = new Uint8ClampedArray(width * height * 4);

            for (let src = 0, dst = 0; dst < rgba.length; src += 3, dst += 4) {
                rgba[dst] = image[src];
                rgba[dst + 1] = image[src + 1];
                rgba[dst + 2] = image[src + 2];
                rgba[dst + 3] = 255;
            }

            ctx.putImageData(new ImageData(rgba, width, height), 0, 0);
        }

        /**
         * Click/submit handler: disables the Run button, clears the canvas, runs the
         * pipeline with the current form values and paints the result.
         *
         * Changes from the previous version:
         *  - a failed run is logged via an explicit rejection handler instead of
         *    surfacing as an unhandled rejection;
         *  - slider values are passed as numbers rather than strings;
         *  - the comment on the last argument said "Decode at each step" although
         *    null disables per-step decoding.
         *
         * The status text and the button state are restored whether the run
         * succeeded or failed.
         */
        const handleRunNetAndRenderResult = () => {
            const runButton = document.getElementById("btnRunNet");
            const statusTitle = document.getElementById("modelDlTitle");
            const canvas = document.getElementById("canvas");

            runButton.disabled = true;
            ctx.clearRect(0, 0, canvas.width, canvas.height);

            const prevTitleValue = statusTitle.innerHTML;
            statusTitle.innerHTML = "Running model";

            runStableDiffusion(
                document.getElementById("promptText").value,
                Number(document.getElementById("stepRange").value),
                Number(document.getElementById("guidanceRange").value),
                // No per-step preview: the image is decoded once, after the last step.
                null
            ).then((image) => {
                renderImage(image);
            }).catch((error) => {
                console.error("Image generation failed:", error);
            }).finally(() => {
                statusTitle.innerHTML = prevTitleValue;
                runButton.disabled = false;
            });
        };

        // Run button: start a generation on click.
        const runNetButton = document.getElementById("btnRunNet");
        runNetButton.addEventListener("click", handleRunNetAndRenderResult, false);

        // Submitting the form (Enter in the prompt field) does the same, but never
        // navigates and is ignored while the button is disabled.
        document.getElementById("promptForm").addEventListener("submit", (event) => {
            event.preventDefault();

            if (!runNetButton.disabled) {
                handleRunNetAndRenderResult();
            }
        });

        // Keep each slider's numeric readout in sync with its current value.
        const sliderReadouts = [
            ['stepRange', 'stepValue'],
            ['guidanceRange', 'guidanceValue']
        ];

        for (const [sliderId, readoutId] of sliderReadouts) {
            const slider = document.getElementById(sliderId);
            const readout = document.getElementById(readoutId);

            slider.addEventListener('input', () => {
                readout.textContent = slider.value;
            });
        }

        // Kick off model download and compilation; the Run button is enabled when it finishes.
        loadNet();
    });
</script>
</body>
</html>
